{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
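   },
   "source": [
    "## Chapter 9: The EM Algorithm and Its Extensions - Exercises\n",
    "### Exercise 9.1\n",
    "\n",
    "&emsp;&emsp;For the three-coin model of Example 9.1, keep the observed data unchanged and choose different initial values, for example $\\pi^{(0)}=0.46,p^{(0)}=0.55,q^{(0)}=0.67$. Find the maximum likelihood estimate of the model parameters $\\theta=(\\pi,p,q)$.  \n",
    "\n",
    "**Solution:**"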
   ]
  },
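  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reference for the code that follows, the EM iteration for the three-coin model can be written out as below. This is a sketch in the notation of Example 9.1, assuming $y_j$ (equal to 0 or 1) is the $j$-th of $n$ observations and $\\mu_j$ is the posterior probability that $y_j$ came from coin B. In the code, `pmf` plays the role of $\\mu_j$, and `pro_A`, `pro_B`, `pro_C` correspond to $\\pi$, $p$, $q$.\n",
    "\n",
    "**E-step:**\n",
    "$$\\mu_j^{(i+1)} = \\frac{\\pi^{(i)} (p^{(i)})^{y_j} (1-p^{(i)})^{1-y_j}}{\\pi^{(i)} (p^{(i)})^{y_j} (1-p^{(i)})^{1-y_j} + (1-\\pi^{(i)}) (q^{(i)})^{y_j} (1-q^{(i)})^{1-y_j}}$$\n",
    "\n",
    "**M-step:**\n",
    "$$\\pi^{(i+1)} = \\frac{1}{n} \\sum_{j=1}^{n} \\mu_j^{(i+1)}, \\quad p^{(i+1)} = \\frac{\\sum_{j=1}^{n} \\mu_j^{(i+1)} y_j}{\\sum_{j=1}^{n} \\mu_j^{(i+1)}}, \\quad q^{(i+1)} = \\frac{\\sum_{j=1}^{n} (1-\\mu_j^{(i+1)}) y_j}{\\sum_{j=1}^{n} (1-\\mu_j^{(i+1)})}$$\n",
    "\n",
    "The code stops once the total absolute change in the three parameters falls below $10^{-5}$. That stopping rule is a choice made in this solution, not part of the exercise statement."
   ]
  },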
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import math\n",
    "\n",
    "\n",
    "class EM:\n",
    "    # EM for the three-coin model; prob = (pi, p, q) are the initial values\n",
    "    def __init__(self, prob):\n",
    "        self.pro_A, self.pro_B, self.pro_C = prob\n",
    "\n",
    "    def pmf(self, y):\n",
    "        # E-step: posterior probability that observation y came from coin B\n",
    "        pro_1 = (self.pro_A * math.pow(self.pro_B, y)\n",
    "                 * math.pow(1 - self.pro_B, 1 - y))\n",
    "        pro_2 = ((1 - self.pro_A) * math.pow(self.pro_C, y)\n",
    "                 * math.pow(1 - self.pro_C, 1 - y))\n",
    "        return pro_1 / (pro_1 + pro_2)\n",
    "\n",
    "    def fit(self, data, tol=0.00001):\n",
    "        print('init prob:{}, {}, {}'.format(self.pro_A, self.pro_B, self.pro_C))\n",
    "        count = len(data)\n",
    "        delta = 1  # total absolute parameter change in the last iteration\n",
    "        d = 0      # iteration counter\n",
    "        while delta > tol:\n",
    "            # E-step: posteriors under the current parameters\n",
    "            _pmf = [self.pmf(y) for y in data]\n",
    "            # M-step: re-estimate pi, p, q\n",
    "            pro_A = sum(_pmf) / count\n",
    "            pro_B = sum(w * y for w, y in zip(_pmf, data)) / sum(_pmf)\n",
    "            pro_C = (sum((1 - w) * y for w, y in zip(_pmf, data))\n",
    "                     / sum(1 - w for w in _pmf))\n",
    "            d += 1\n",
    "            print('{}  pro_a:{:.4f}, pro_b:{:.4f}, pro_c:{:.4f}'.format(d, pro_A, pro_B, pro_C))\n",
    "            delta = abs(self.pro_A-pro_A) + abs(self.pro_B-pro_B) + abs(self.pro_C-pro_C)\n",
    "            self.pro_A, self.pro_B, self.pro_C = pro_A, pro_B, pro_C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "init prob:0.46, 0.55, 0.67\n",
      "1  pro_a:0.4619, pro_b:0.5346, pro_c:0.6561\n",
      "2  pro_a:0.4619, pro_b:0.5346, pro_c:0.6561\n"
     ]
    }
   ],
   "source": [
    "# Observed data from Example 9.1 (1 = heads, 0 = tails)\n",
    "data=[1,1,0,1,0,0,1,0,1,1]\n",
    "\n",
    "em = EM(prob=[0.46, 0.55, 0.67])\n",
    "em.fit(data)"
   ]
  },
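  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameters converge after two iterations, giving the estimates $\\hat{\\pi}=0.4619$, $\\hat{p}=0.5346$, $\\hat{q}=0.6561$ for coins A, B, and C. These values depend on the chosen initial values, because the EM algorithm is sensitive to initialization."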
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 9.2\n",
    "Prove Lemma 9.2."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> **Lemma 9.2:** If $\\tilde{P}_{\\theta}(Z)=P(Z | Y, \\theta)$, then $$F(\\tilde{P}, \\theta)=\\log P(Y|\\theta)$$\n",
    "\n",
    "**Proof:**  \n",
    "By the definition of the $F$ function (**Definition 9.3**): $$F(\\tilde{P}, \\theta)=E_{\\tilde{P}}[\\log P(Y,Z|\\theta)] + H(\\tilde{P})$$ where $H(\\tilde{P})=-E_{\\tilde{P}} \\log \\tilde{P}(Z)$ is the entropy of $\\tilde{P}$.  \n",
    "$\\begin{aligned}\n",
    "\\therefore F(\\tilde{P}, \\theta) \n",
    "&= E_{\\tilde{P}}[\\log P(Y,Z|\\theta)] -E_{\\tilde{P}} \\log \\tilde{P}(Z) \\\\\n",
    "&= \\sum_Z \\log P(Y,Z|\\theta) \\tilde{P}_{\\theta}(Z) - \\sum_Z \\log \\tilde{P}_{\\theta}(Z) \\cdot \\tilde{P}_{\\theta}(Z) \\\\\n",
    "&= \\sum_Z \\log P(Y,Z|\\theta) P(Z|Y,\\theta) -  \\sum_Z \\log P(Z|Y,\\theta) \\cdot P(Z|Y,\\theta) \\\\\n",
    "&= \\sum_Z P(Z|Y,\\theta) \\left[ \\log P(Y,Z|\\theta) - \\log P(Z|Y,\\theta) \\right] \\\\\n",
    "&= \\sum_Z P(Z|Y,\\theta) \\log \\frac{P(Y,Z|\\theta)}{P(Z|Y,\\theta)} \\\\\n",
    "&= \\sum_Z P(Z|Y,\\theta) \\log P(Y|\\theta) \\\\\n",
    "&= \\log P(Y|\\theta) \\sum_Z P(Z|Y,\\theta)\n",
    "\\end{aligned}$  \n",
    "$\\displaystyle \\because \\sum_Z \\tilde{P}_{\\theta}(Z) = \\sum_Z P(Z|Y, \\theta) = 1$  \n",
    "$\\therefore F(\\tilde{P}, \\theta) = \\log P(Y|\\theta)$, which proves Lemma 9.2."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 9.3\n",
    "Given the observed data  \n",
    "-67, -48, 6, 8, 14, 16, 23, 24, 28, 29, 41, 49, 56, 60, 75  \n",
    "estimate the 5 parameters of a two-component Gaussian mixture model.\n",
    "\n",
    "**Solution:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "labels = [1 1 0 0 0 0 0 0 0 0 0 0 0 0 0]\n"
     ]
    }
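   ],
   "source": [
    "from sklearn.mixture import GaussianMixture\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Observed data, reshaped into a single feature column as scikit-learn expects\n",
    "data=np.array([-67, -48, 6, 8, 14, 16, 23, 24, 28, 29, 41, 49, 56, 60, 75]).reshape(-1, 1)\n",
    "\n",
    "# Fit a two-component Gaussian mixture and assign each point to a component\n",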
    "gmmModel = GaussianMixture(n_components=2)\n",
    "gmmModel.fit(data)\n",
    "labels = gmmModel.predict(data)\n",
    "print(\"labels =\", labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAEWCAYAAACNJFuYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZ00lEQVR4nO3de7hddX3n8fdHoiA3ARMCEiqooEZqoz2iYIexAgMqNT4zjUVF07F9mHpvR0egdFrtlJbaPrZ0es14Q0UpwQu01SpiLU8rXgJELaJC64UIhKMWiGjBA9/5Y63o9qydkwM5+6yz2e/X8+RZe132Wt99cs7+rN9v3VJVSJI06EF9FyBJWnoMB0lSh+EgSeowHCRJHYaDJKnDcJAkdRgOekBK8uEk63vc/k8k+W6S3fqqYbEleUOSd89z2U8k+eVR16T7z3DQgkhyapJPJ7kzya3t65cnSR/1VNWzqur8hV5vkl9MUknePGv689rp72i3/42q2ruq7pnHOt+R5HcWutY5tveMttb3z5r+U+30TyxWLVq6DAftsiSvBc4D/gA4CFgJ/ArwdOAhPZY2Kv8K/EKSZQPTXgJ8pY9iZtUxX9PAsUkePjBtPT19Bi09hoN2SZKHAb8NvLyqLq6qbdW4pqpeVFV3tcs9J8k1Se5IcmOSNwys4xlJtsxa79eSnNC+PjrJpva9W7fvtSfZI8m7k3w7yW1JPptkZTvvh90WSR6d5OPtct9KckGS/WZt63VJPp/k9iR/nWSPOT72LcAXgJPa9x8AHAtcOrDOw9q98GVJDkiyJcnPtfP2TnJDkpckOR14EfD6thvqb9plKsljBtb3w9bF9p9XkjOS3AK8PcmDkpyZ5F/bz3lRW9eO3A18EDi1XeduwPOBC2b9Pxzb/lxvb4fHDsw7PMk/JtmW5DJg+az3Pi3JJ9v/m88lecYc9WiJMRy0q44Bdgcu2clyd9LsXe8HPAd4WZLnzXMb5wHnVdW+wKOBi9rp64GHAYcCD6dprXx/yPsD/B7wCODx7fJvmLXM84GTgcOBJwK/uJOa3tl+Hmi+YC8B7hq2YFV9B3gp8P+SHAj8EbC5qt5ZVRtovpDf1HZD/dxOtrvdQcABwCOB04FXA88D/nP7Of8d+LP78BlOAq4Fbto+sw2XvwP+hObn+2bg7wZaG+8BrqIJhf9D8/+x/b2HtO/9nbbO1wHvS7Jinp9PPTMctKuWA9+qqpntEwb2Fr+f5DiAqvpEVX2hqu6tqs8D76X5IpuPHwCPSbK8qr5bVZ8amP5w4DFVdU9VXVVVd8x+c1XdUFWXVdVdVTVN8yU3e9t/UlU3tV/kfwOs2UlNHwCe0bacXkLzRbtDVfVRYCNwOU04/o+drH9n7gV+q/1M32/Xd3ZVbWlba28Afn6uLqeq+iRwQJLH7uAzPAe4vqreVVUzVfVe4EvAzyX5CeApwP9ua7iC5ue23WnAh6rqQ+3/+WXAJuDZu/i5tUgMB+2qbwPLB7+EqurYqtqvnfcggCRPTfIPSaaT3E6zl7982AqH+CXgSOBLbdfGKe30dwEfAS5MclOSNyV58Ow3JzkwyYVJvpnkDuDdQ7Z9y8Dr7wF7z1VQ+4X8d8BvAMur6p/n8Tk2AEcBb6+qb89j+blMV9V/DIw/EvhAG8q3AdcB99Ac/5nLu4BXAj9LE3iDHgF8fda0rwOHtPP+varunDVvsJ512+tpa/oZ4OCdfTAtDYaDdtWVNN0pa3ey3Hto+uQPraqHAX9J090DTZfTntsXbPu/f9j9UFXXV9ULgAOB3wcuTrJXVf2gqt5YVatp+vxP4UfdJIN+DyjgiW3X1GkD294V7wReS/MFO6f2M/1V+56XDR5PaGub7XsM/ExoupEGzX7PjcCzqmq/gX97VNU3d1Lau4CX0+zlf2/WvJtovuQH/QTwTeBmYP8ke82aN1jPu2bVs1dVnbuTerREGA7aJVV1G/BG4M+T/Hx7sPVB
SdYAg18c+wDfqar/SHI08MKBeV8B9mgPWj+YZm989+0zk5yWZEVV3Qvc1k6+J8nPJvnJ9ov3DppupmGnju4DfBe4re0L/1+7/skB+EfgROD/zmPZX2+HLwX+EHhnfnQNxFbgUbOW3wy8MMluSU5m511wfwmck+SRAElWJNlZYFNVX23XffaQ2R8CjkzywvbA+i8Aq4G/raqv03QTvTHJQ5L8DDB4vOTdNN1PJ7WfYY/2QPqqndWkpcFw0C6rqjcB/xN4PXArzZfdXwFnAJ9sF3s58NtJtgG/yY8OKlNVt7fz30KzV3onMHj20snAtUm+S3Nw+tS2S+Ug4GKaYLiO5st62EVYbwSeDNxO0xX0/iHL3GftWVmXt8cpdijJT9P8fF7SXvfw+zR7/me2i7wVWN12v3ywnfYami/b22jOZvogczuPpmX20fZn/CngqfP8HP9UVTcNmf5tmtbYa2m6CF8PnFJV32oXeWG7je8Av8XAMYuqupGmNfnrNKfN3kgTyn7njIn4sB9J0mymuCSpw3CQJHUYDpKkDsNBktRxf27YtWCS/BrwyzRnbnwB+O8053b/NXAY8DXg+VX173OtZ/ny5XXYYYeNslRJesC56qqrvlVVQ29p0tvZSu355v8ErK6q7ye5iOa86tU058Ofm+RMYP+qOmOudU1NTdWmTZtGX7QkPYAkuaqqpobN67tbaRnw0PbWC3vSXJG5Fth+H/7zaW4mJklaRL2FQ3tZ/x8C36C5FP/29uZkK6vq5naZm2lumdCR5PQ0t3HeND09vVhlS9JE6C0ckuxP00o4nOYmXnslOW2+76+qDVU1VVVTK1Z4F2BJWkh9diudAHy1qqar6gc0tzQ4Ftia5GCAdnhrjzVK0kTqMxy+ATwtyZ5JAhxPc3+cS/nRQ0PWs/OHyEiSFlhvp7JW1aeTXAxcDcwA19Dc735v4KIkv0QTIOv6qlGSJlWvZytV1W9V1eOq6qiqenH7RKlvV9XxVXVEO5zzjpeSNJG2boXjjoN9922GW7cu6Or7PpVVknR/rFsHV14J27Y1w3UL28liOEjSONq8GWbaR7fPzDTjC8hwkKRxtGYNLGsPGy9b1owvIMNBksbRxo1wzDGwzz7NcOPGBV19rzfekyTdTytXwhVXjGz1thwkSR2GgySpw3CQJHUYDpKkDsNBktRhOEiSOgwHSVKH4SBJ6jAcJEkdhoMkqcNwkCR1GA6SpA7DQZLUYThIkjp6DYck+yW5OMmXklyX5JgkByS5LMn17XD/PmuUpEnUd8vhPODvq+pxwE8B1wFnApdX1RHA5e24JI2nrVvhuONg332b4datfVc0L72FQ5J9geOAtwJU1d1VdRuwFji/Xex84Hl91CdJC2LdOrjySti2rRmuW9d3RfPSZ8vhUcA08PYk1yR5S5K9gJVVdTNAOzxw2JuTnJ5kU5JN09PTi1e1JN0XmzfDzEzzemamGR8DfYbDMuDJwF9U1ZOAO7kPXUhVtaGqpqpqasWKFaOqUZJ2zZo1sKx9IvOyZc34GOgzHLYAW6rq0+34xTRhsTXJwQDt8Nae6pOkXbdxIxxzDOyzTzPcuLHviuZlWV8brqpbktyY5LFV9WXgeOCL7b/1wLnt8JK+apSkXbZyJVxxRd9V3Gd9n630KuCCJJ8H1gC/SxMKJya5HjixHZek0RrTs4pGpbeWA0BVbQamhsw6fpFLkTTptp9VNDPzo7OKxnCPf6H03XKQpKVhTM8qGhXDQZJgbM8qGhXDQZJgbM8qGpVejzlI0pIxpmcVjYotB0lSh+EgSeowHCRJHYaDpPHjBWsjZzhIGj9jehvscWI4SBo/XrA2coaDpPHjBWsjZzhIGj9esDZyXgQnafx4wdrI2XKQJHUYDpKkDsNBktRhOEiSOgwHSVKH4SBJ6ug9HJLsluSaJH/bjh+Q5LIk17fD/fuuUdL95D2Qxlbv4QC8BrhuYPxM4PKqOgK4vB2XNI68B9LY6jUckqwCngO8
ZWDyWuD89vX5wPMWuSxJC8V7II2tvlsOfwy8Hrh3YNrKqroZoB0eOOyNSU5PsinJpunp6ZEXKul+8B5IY6u3cEhyCnBrVV11f95fVRuqaqqqplasWLHA1UlaEN4DaWz1eW+lpwPPTfJsYA9g3yTvBrYmObiqbk5yMHBrjzVK2hXeA2ls9dZyqKqzqmpVVR0GnAp8vKpOAy4F1reLrQcu6alESZpYfR9zGOZc4MQk1wMntuOSpEW0JG7ZXVWfAD7Rvv42cHyf9UjSpFuKLQdJUs8MB0lSh+EgSeowHCRJHYaDJKnDcJAkdRgOkqQOw0GSz11Qh+EgyecuqMNwkMbJqPbwfe6CZjEcpHEyqj18n7ugWQwHaZyMag/f5y5oliVx4z1J87RmTdNimJlZ2D18n7ugWWw5SOPEPXwtEsNBGpVRHDzevod/xx3NcOXKXV+nNIThII2Kp4dqjBkOkqeHSh2Gg+TpoVKH4SB5eqjU0Vs4JDk0yT8kuS7JtUle004/IMllSa5vh/v3VaMmxKj28D14rDHWZ8thBnhtVT0eeBrwiiSrgTOBy6vqCODydlwaHffwpY7eLoKrqpuBm9vX25JcBxwCrAWe0S52PvAJ4IweStSk8AIwqWNJHHNIchjwJODTwMo2OLYHyIE7eM/pSTYl2TQ9Pb1otUrSJOg9HJLsDbwP+NWqumO+76uqDVU1VVVTK1asGF2BkjSBeg2HJA+mCYYLqur97eStSQ5u5x8M3NpXfZI0qfo8WynAW4HrqurNA7MuBda3r9cDlyx2bVqifFqZtGj6bDk8HXgx8Mwkm9t/zwbOBU5Mcj1wYjsueTsKaRH1ebbSPwHZwezjF7MWjQlvRyEtmt4PSOsBahRdQN6OQlo0hoNGYxRdQF6sJi0anwSn0RhFF5AXq0mLxpaDRsMuIGmsGQ4aDbuApLFmt5JGwy4gaazZcpAkdRgOkqQOw0GS1GE4TDrvVyRpCMNh0nm/IklDGA7jYlR7+N6vSNIQhsO4GNUevherSRrCcBgXo9rD92I1SUN4Edy4WLOmaTHMzCzsHr4Xq0kaYqcthySvTLL/YhSjObiHL2kRzaflcBDw2SRXA28DPlJVNdqy1OEevqRFtNOWQ1X9BnAEzfOefxG4PsnvJnn0iGuTJPVkXgek25bCLe2/GWB/4OIkbxpVYUlOTvLlJDckOXNU25Ekdc3nmMOrk1wFvAn4Z+Anq+plwE8D/20URSXZDfgz4FnAauAFSVaPYluSpK75HHNYDvzXqvr64MSqujfJKaMpi6OBG6rq3wCSXAisBb44ou1JkgbsNByq6jfnmHfdwpbzQ4cANw6MbwGeOqJtSZJmWaoXwWXItB87QyrJ6Uk2Jdk0PT29SGVJ0mRYquGwBTh0YHwVcNPgAlW1oaqmqmpqxYoVi1qcJD3QLdVw+CxwRJLDkzwEOBW4tOeaJGliLMnbZ1TVTJJXAh8BdgPeVlXX9lyWJE2MJRkOAFX1IeBDfdchSZNoqXYrSZJ6ZDhIkjoMB0lSh+EgSeowHCRJHYaDJKnDcJAkdRgOkqQOw0GS1GE4SJI6DAdJUofhIEnqMBwkSR2GgySpw3CQJHUYDpKkDsNBktRhOEiSOgwHSVJHL+GQ5A+SfCnJ55N8IMl+A/POSnJDki8nOamP+iRp0vXVcrgMOKqqngh8BTgLIMlq4FTgCcDJwJ8n2a2nGiVpYvUSDlX10aqaaUc/BaxqX68FLqyqu6rqq8ANwNF91ChJk2wpHHN4KfDh9vUhwI0D87a00zqSnJ5kU5JN09PTIy5RkibLslGtOMnHgIOGzDq7qi5plzkbmAEu2P62IcvXsPVX1QZgA8DU1NTQZSRJ98/IwqGqTphrfpL1wCnA8VW1/ct9C3DowGKrgJtGU6EkaUf6OlvpZOAM4LlV9b2BWZcCpybZPcnhwBHAZ/qoUZIm2chaDjvxp8DuwGVJAD5VVb9SVdcmuQj4Ik130yuq6p6eapSkidVLOFTV
Y+aYdw5wziKWI0maZSmcrSRJWmIMB0lSh+EgSeowHCRJHYaDJKnDcJAkdRgOkqQOw0GS1GE4SJI6DAdJUofhIEnqMBwkSR2GgySpw3CQJHUYDpKkDsNBktRhOEiSOgwHSVKH4SBJ6ug1HJK8LkklWT4w7awkNyT5cpKT+qxPkibVsr42nORQ4ETgGwPTVgOnAk8AHgF8LMmRVXVPP1VK0mTqs+XwR8DrgRqYtha4sKruqqqvAjcAR/dRnCRNsl7CIclzgW9W1edmzToEuHFgfEs7bdg6Tk+yKcmm6enpEVUqSZNpZN1KST4GHDRk1tnArwP/ZdjbhkyrIdOoqg3ABoCpqamhy0iS7p+RhUNVnTBsepKfBA4HPpcEYBVwdZKjaVoKhw4svgq4aVQ1SpKGW/Rupar6QlUdWFWHVdVhNIHw5Kq6BbgUODXJ7kkOB44APrPYNUrSpOvtbKVhquraJBcBXwRmgFd4ppIkLb7ew6FtPQyOnwOc0081kiTwCmlJ0hCGgySpw3CQJHUYDgts61Y47jjYd99muHVr3xVJ0n1nOCywdevgyith27ZmuG5d3xVJ0n1nOCywzZthZqZ5PTPTjEvSuDEcFtiaNbCsPUF42bJmXJLGjeGwwDZuhGOOgX32aYYbN/ZdkSTdd71fBPdAs3IlXHFF31VI0q6x5SBJ6jAcJEkdhoMkqcNwkCR1GA6SpA7DQZLUYThIkjoMB0lSh+EgSeowHCRJHb2FQ5JXJflykmuTvGlg+llJbmjnndRXfZI0yXq5t1KSnwXWAk+sqruSHNhOXw2cCjwBeATwsSRHVtU9fdQpSZOqr5bDy4Bzq+ougKq6tZ2+Friwqu6qqq8CNwBH91SjJE2svsLhSOA/Jfl0kn9M8pR2+iHAjQPLbWmndSQ5PcmmJJump6dHXK4kTZaRdSsl+Rhw0JBZZ7fb3R94GvAU4KIkjwIyZPkatv6q2gBsAJiamhq6jCTp/hlZOFTVCTual+RlwPurqoDPJLkXWE7TUjh0YNFVwE2jqlGSNFxf3UofBJ4JkORI4CHAt4BLgVOT7J7kcOAI4DM91ShJE6uvJ8G9DXhbkn8B7gbWt62Ia5NcBHwRmAFe4ZlKkrT4egmHqrobOG0H884BzlnciiRJg7xCWpLUYThIkjomOhy2boXjjoN9922GW7f2XZEkLQ0THQ7r1sGVV8K2bc1w3bq+K5KkpWGiw2HzZpiZaV7PzDTjkqQJD4c1a2BZe77WsmXNuCRpwsNh40Y45hjYZ59muHFj3xVJ0tLQ10VwS8LKlXDFFX1XIUlLz0S3HCRJwxkOkqQOw0GS1GE4SJI6DAdJUofhIEnqSPMYhfGWZBr4+i6sYjnNw4bGwTjVCuNVr7WOzjjVO061wq7V+8iqWjFsxgMiHHZVkk1VNdV3HfMxTrXCeNVrraMzTvWOU60wunrtVpIkdRgOkqQOw6Gxoe8C7oNxqhXGq15rHZ1xqnecaoUR1esxB0lShy0HSVKH4SBJ6pjocEhycpIvJ7khyZl91zOXJIcm+Yck1yW5Nslr+q5pZ5LsluSaJH/bdy07k2S/JBcn+VL7Mz6m75p2JMmvtb8D/5LkvUn26LumQUneluTWJP8yMO2AJJclub4d7t9njdvtoNY/aH8PPp/kA0n267HEHzOs3oF5r0tSSZYvxLYmNhyS7Ab8GfAsYDXwgiSr+61qTjPAa6vq8cDTgFcs8XoBXgNc13cR83Qe8PdV9Tjgp1iidSc5BHg1MFVVRwG7Aaf2W1XHO4CTZ007E7i8qo4ALm/Hl4J30K31MuCoqnoi8BXgrMUuag7voFsvSQ4FTgS+sVAbmthwAI4Gbqiqf6uqu4ELgbU917RDVXVzVV3dvt5G8+V1SL9V7ViSVcBzgLf0XcvOJNkXOA54K0BV3V1Vt/Va1NyWAQ9NsgzYE7ip53p+TFVdAXxn1uS1wPnt6/OB5y1mTTsyrNaq+mhVtU+X51PAqkUvbAd28LMF+CPg
9cCCnWE0yeFwCHDjwPgWlvCX7aAkhwFPAj7dcylz+WOaX9Z7e65jPh4FTANvb7vB3pJkr76LGqaqvgn8Ic0e4s3A7VX10X6rmpeVVXUzNDs6wIE91zNfLwU+3HcRc0nyXOCbVfW5hVzvJIdDhkxb8uf1JtkbeB/wq1V1R9/1DJPkFODWqrqq71rmaRnwZOAvqupJwJ0snW6PH9P21a8FDgceAeyV5LR+q3pgSnI2TXfuBX3XsiNJ9gTOBn5zodc9yeGwBTh0YHwVS6x5PluSB9MEwwVV9f6+65nD04HnJvkaTXfdM5O8u9+S5rQF2FJV21tiF9OExVJ0AvDVqpquqh8A7weO7bmm+dia5GCAdnhrz/XMKcl64BTgRbW0LwZ7NM2Owufav7dVwNVJDtrVFU9yOHwWOCLJ4UkeQnNQ79Kea9qhJKHpE7+uqt7cdz1zqaqzqmpVVR1G83P9eFUt2b3bqroFuDHJY9tJxwNf7LGkuXwDeFqSPdvfieNZogfPZ7kUWN++Xg9c0mMtc0pyMnAG8Nyq+l7f9cylqr5QVQdW1WHt39sW4Mnt7/QumdhwaA84vRL4CM0f10VVdW2/Vc3p6cCLafbCN7f/nt13UQ8grwIuSPJ5YA3wu/2WM1zburkYuBr4As3f8JK63UOS9wJXAo9NsiXJLwHnAicmuZ7mrJpz+6xxux3U+qfAPsBl7d/ZX/Za5IAd1DuabS3tFpMkqQ8T23KQJO2Y4SBJ6jAcJEkdhoMkqcNwkCR1GA6SpA7DQZLUYThII5DkKe3zAPZIslf7/IWj+q5Lmi8vgpNGJMnvAHsAD6W5d9Pv9VySNG+GgzQi7T27Pgv8B3BsVd3Tc0nSvNmtJI3OAcDeNPfpWVKP8pR2xpaDNCJJLqW5ZfnhwMFV9cqeS5LmbVnfBUgPREleAsxU1Xva55V/Mskzq+rjfdcmzYctB0lSh8ccJEkdhoMkqcNwkCR1GA6SpA7DQZLUYThIkjoMB0lSx/8HxhBa9dXPOdUAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "means = [[ 32.98489643 -57.51107027]]\n",
      "covariances = [[429.45764867  90.24987882]]\n",
      "weights =  [[0.86682762 0.13317238]]\n"
     ]
    }
   ],
   "source": [
    "for i in range(0,len(labels)):\n",
    "    if labels[i] == 0:\n",
    "        plt.scatter(i, data.take(i), s=15, c='red')\n",
    "    elif labels[i] == 1:\n",
    "        plt.scatter(i, data.take(i), s=15, c='blue')\n",
    "plt.title('Gaussian Mixture Model')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('y')\n",
    "plt.show()\n",
    "print(\"means =\", gmmModel.means_.reshape(1, -1))\n",
    "print(\"covariances =\", gmmModel.covariances_.reshape(1, -1))\n",
    "print(\"weights = \", gmmModel.weights_.reshape(1, -1))"
   ]
  },
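  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "scikit-learn documents `GaussianMixture` as being fitted with the EM algorithm. One EM iteration for this exercise can be sketched as follows, assuming $K=2$ components, $N=15$ observations $y_j$, mixing weights $\\alpha_k$, and $\\phi(y | \\theta_k)$ the Gaussian density with mean $\\mu_k$ and variance $\\sigma_k^2$. The library's own initialization and covariance regularization are not shown here.\n",
    "\n",
    "**E-step:**\n",
    "$$\\hat{\\gamma}_{jk} = \\frac{\\alpha_k \\phi(y_j | \\theta_k)}{\\sum_{k=1}^{K} \\alpha_k \\phi(y_j | \\theta_k)}$$\n",
    "\n",
    "**M-step:**\n",
    "$$\\hat{\\mu}_k = \\frac{\\sum_{j=1}^{N} \\hat{\\gamma}_{jk} y_j}{\\sum_{j=1}^{N} \\hat{\\gamma}_{jk}}, \\quad \\hat{\\sigma}_k^2 = \\frac{\\sum_{j=1}^{N} \\hat{\\gamma}_{jk} (y_j - \\hat{\\mu}_k)^2}{\\sum_{j=1}^{N} \\hat{\\gamma}_{jk}}, \\quad \\hat{\\alpha}_k = \\frac{\\sum_{j=1}^{N} \\hat{\\gamma}_{jk}}{N}$$\n",
    "\n",
    "Read against the printed output, the five estimated parameters are $\\alpha_1 \\approx 0.8668$ (so $\\alpha_2 \\approx 0.1332$), $\\mu_1 \\approx 32.98$, $\\sigma_1^2 \\approx 429.46$, $\\mu_2 \\approx -57.51$, and $\\sigma_2^2 \\approx 90.25$. The component order, and possibly the values, may change between runs because no random seed is fixed."
   ]
  },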
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 9.4\n",
    "&emsp;&emsp;The EM algorithm can be applied to unsupervised learning with the naive Bayes method. Write out the algorithm."
   ]
  },
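  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Solution:** \n",
    "> **The general EM algorithm:**  \n",
    "**E-step:** Using the initial parameters or the parameters from the previous iteration, compute the posterior probability of the latent variable (in effect its expectation) and take it as the current estimate of the latent variable: $$w_j^{(i)}=Q_{i}(z^{(i)}=j) := p(z^{(i)}=j | x^{(i)} ; \\theta)$$\n",
    "**M-step:** Maximize the resulting lower bound on the log-likelihood to obtain new parameter values: $$\n",
    "\\theta :=\\arg \\max_{\\theta} \\sum_i \\sum_{z^{(i)}} Q_i (z^{(i)}) \\log \\frac{p(x^{(i)}, z^{(i)} ; \\theta)}{Q_i (z^{(i)})}\n",
    "$$  \n",
    "\n",
    "For the NBMM (mixture of naive Bayes models) with a binary latent class $z^{(i)}$ and binary features $x_j^{(i)}$, replace $\\theta$ in the general EM algorithm with the parameters $\\phi_z,\\phi_{j|z^{(i)}=1},\\phi_{j|z^{(i)}=0}$, then alternately update $w^{(i)}$ and these parameters.  \n",
    "**EM algorithm for NBMM:**  \n",
    "**E-step:**  \n",
    "$$w^{(i)}:=P\\left(z^{(i)}=1 | x^{(i)} ; \\phi_z, \\phi_{j | z^{(i)}=1}, \\phi_{j | z^{(i)}=0}\\right)$$\n",
    "By Bayes' rule and the naive Bayes conditional-independence assumption, with $\\phi_z = P(z^{(i)}=1)$ and $p(x^{(i)} | z^{(i)}=k) = \\prod_j \\phi_{j | z^{(i)}=k}^{x_j^{(i)}} (1-\\phi_{j | z^{(i)}=k})^{1-x_j^{(i)}}$, this posterior is $$w^{(i)} = \\frac{\\phi_z \\, p(x^{(i)} | z^{(i)}=1)}{\\phi_z \\, p(x^{(i)} | z^{(i)}=1) + (1-\\phi_z) \\, p(x^{(i)} | z^{(i)}=0)}$$\n",
    "**M-step:**$$\n",
    "\\begin{aligned}\n",
    "\\phi_{j | z^{(i)}=1} &:=\\frac{\\displaystyle \\sum_{i=1}^{m} w^{(i)} I(x_{j}^{(i)}=1)}{\\displaystyle \\sum_{i=1}^{m} w^{(i)}} \\\\ \n",
    "\\phi_{j | z^{(i)}=0} &:= \\frac{\\displaystyle  \\sum_{i=1}^{m}\\left(1-w^{(i)}\\right) I(x_{j}^{(i)}=1)}{ \\displaystyle \\sum_{i=1}^{m}\\left(1-w^{(i)}\\right)} \\\\ \n",
    "\\phi_z &:=\\frac{\\displaystyle \\sum_{i=1}^{m} w^{(i)}}{m} \n",
    "\\end{aligned}\n",
    "$$   "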
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
